{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "### facenet-pytorch LFW evaluation\n",
    "This notebook demonstrates how to evaluate performance against the LFW dataset."
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "outputs": [],
   "source": [
    "from facenet_pytorch import MTCNN, InceptionResnetV1, fixed_image_standardization, training, extract_face\n",
    "import torch\n",
    "from torch.utils.data import DataLoader, SubsetRandomSampler, SequentialSampler\n",
    "from torchvision import datasets, transforms\n",
    "import numpy as np\n",
    "import os"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "outputs": [],
   "source": [
    "data_dir = 'data/lfw/lfw'\n",
    "pairs_path = 'data/lfw/pairs.txt'\n",
    "\n",
    "batch_size = 16\n",
    "epochs = 15\n",
    "workers = 0 if os.name == 'nt' else 8"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running on device: cuda:0\n"
     ]
    }
   ],
   "source": [
    "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
    "print('Running on device: {}'.format(device))"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "outputs": [],
   "source": [
    "mtcnn = MTCNN(\n",
    "    image_size=160,\n",
    "    margin=14,\n",
    "    device=device,\n",
    "    selection_method='center_weighted_size'\n",
    ")"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "outputs": [],
   "source": [
    "# Define the data loader for the input set of images\n",
    "orig_img_ds = datasets.ImageFolder(data_dir, transform=None)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "outputs": [],
   "source": [
    "\n",
    "# overwrites class labels in dataset with path so path can be used for saving output in mtcnn batches\n",
    "orig_img_ds.samples = [\n",
    "    (p, p)\n",
    "    for p, _ in orig_img_ds.samples\n",
    "]\n",
    "\n",
    "loader = DataLoader(\n",
    "    orig_img_ds,\n",
    "    num_workers=workers,\n",
    "    batch_size=batch_size,\n",
    "    collate_fn=training.collate_pil\n",
    ")\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "crop_paths = []\n",
    "box_probs = []\n",
    "\n",
    "for i, (x, b_paths) in enumerate(loader):\n",
    "    crops = [p.replace(data_dir, data_dir + '_cropped') for p in b_paths]\n",
    "    mtcnn(x, save_path=crops)\n",
    "    crop_paths.extend(crops)\n",
    "    print('\\rBatch {} of {}'.format(i + 1, len(loader)), end='')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "outputs": [],
   "source": [
    "# Remove mtcnn to reduce GPU memory usage\n",
    "del mtcnn\n",
    "torch.cuda.empty_cache()"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "outputs": [],
   "source": [
    "# create dataset and data loaders from cropped images output from MTCNN\n",
    "\n",
    "trans = transforms.Compose([\n",
    "    np.float32,\n",
    "    transforms.ToTensor(),\n",
    "    fixed_image_standardization\n",
    "])\n",
    "\n",
    "dataset = datasets.ImageFolder(data_dir + '_cropped', transform=trans)\n",
    "\n",
    "embed_loader = DataLoader(\n",
    "    dataset,\n",
    "    num_workers=workers,\n",
    "    batch_size=batch_size,\n",
    "    sampler=SequentialSampler(dataset)\n",
    ")"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "outputs": [],
   "source": [
    "# Load pretrained resnet model\n",
    "resnet = InceptionResnetV1(\n",
    "    classify=False,\n",
    "    pretrained='vggface2'\n",
    ").to(device)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "outputs": [],
   "source": [
    "classes = []\n",
    "embeddings = []\n",
    "resnet.eval()\n",
    "with torch.no_grad():\n",
    "    for xb, yb in embed_loader:\n",
    "        xb = xb.to(device)\n",
    "        b_embeddings = resnet(xb)\n",
    "        b_embeddings = b_embeddings.to('cpu').numpy()\n",
    "        classes.extend(yb.numpy())\n",
    "        embeddings.extend(b_embeddings)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "outputs": [],
   "source": [
    "embeddings_dict = dict(zip(crop_paths,embeddings))\n",
    "\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "#### Evaluate embeddings by using distance metrics to perform verification on the official LFW test set.\n",
    "\n",
    "The functions in the next block are copy pasted from `facenet.src.lfw`. Unfortunately that module has an absolute import from `facenet`, so can't be imported from the submodule\n",
    "\n",
    "added functionality to return false positive and false negatives"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "outputs": [],
   "source": [
    "from sklearn.model_selection import KFold\n",
    "from scipy import interpolate\n",
    "\n",
    "# LFW functions taken from David Sandberg's FaceNet implementation\n",
    "def distance(embeddings1, embeddings2, distance_metric=0):\n",
    "    if distance_metric==0:\n",
    "        # Euclidian distance\n",
    "        diff = np.subtract(embeddings1, embeddings2)\n",
    "        dist = np.sum(np.square(diff),1)\n",
    "    elif distance_metric==1:\n",
    "        # Distance based on cosine similarity\n",
    "        dot = np.sum(np.multiply(embeddings1, embeddings2), axis=1)\n",
    "        norm = np.linalg.norm(embeddings1, axis=1) * np.linalg.norm(embeddings2, axis=1)\n",
    "        similarity = dot / norm\n",
    "        dist = np.arccos(similarity) / math.pi\n",
    "    else:\n",
    "        raise 'Undefined distance metric %d' % distance_metric\n",
    "\n",
    "    return dist\n",
    "\n",
    "def calculate_roc(thresholds, embeddings1, embeddings2, actual_issame, nrof_folds=10, distance_metric=0, subtract_mean=False):\n",
    "    assert(embeddings1.shape[0] == embeddings2.shape[0])\n",
    "    assert(embeddings1.shape[1] == embeddings2.shape[1])\n",
    "    nrof_pairs = min(len(actual_issame), embeddings1.shape[0])\n",
    "    nrof_thresholds = len(thresholds)\n",
    "    k_fold = KFold(n_splits=nrof_folds, shuffle=False)\n",
    "\n",
    "    tprs = np.zeros((nrof_folds,nrof_thresholds))\n",
    "    fprs = np.zeros((nrof_folds,nrof_thresholds))\n",
    "    accuracy = np.zeros((nrof_folds))\n",
    "\n",
    "    is_false_positive = []\n",
    "    is_false_negative = []\n",
    "\n",
    "    indices = np.arange(nrof_pairs)\n",
    "\n",
    "    for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)):\n",
    "        if subtract_mean:\n",
    "            mean = np.mean(np.concatenate([embeddings1[train_set], embeddings2[train_set]]), axis=0)\n",
    "        else:\n",
    "          mean = 0.0\n",
    "        dist = distance(embeddings1-mean, embeddings2-mean, distance_metric)\n",
    "\n",
    "        # Find the best threshold for the fold\n",
    "        acc_train = np.zeros((nrof_thresholds))\n",
    "        for threshold_idx, threshold in enumerate(thresholds):\n",
    "            _, _, acc_train[threshold_idx], _ ,_ = calculate_accuracy(threshold, dist[train_set], actual_issame[train_set])\n",
    "        best_threshold_index = np.argmax(acc_train)\n",
    "        for threshold_idx, threshold in enumerate(thresholds):\n",
    "            tprs[fold_idx,threshold_idx], fprs[fold_idx,threshold_idx], _, _, _ = calculate_accuracy(threshold, dist[test_set], actual_issame[test_set])\n",
    "        _, _, accuracy[fold_idx], is_fp, is_fn = calculate_accuracy(thresholds[best_threshold_index], dist[test_set], actual_issame[test_set])\n",
    "\n",
    "        tpr = np.mean(tprs,0)\n",
    "        fpr = np.mean(fprs,0)\n",
    "        is_false_positive.extend(is_fp)\n",
    "        is_false_negative.extend(is_fn)\n",
    "\n",
    "    return tpr, fpr, accuracy, is_false_positive, is_false_negative\n",
    "\n",
    "def calculate_accuracy(threshold, dist, actual_issame):\n",
    "    predict_issame = np.less(dist, threshold)\n",
    "    tp = np.sum(np.logical_and(predict_issame, actual_issame))\n",
    "    fp = np.sum(np.logical_and(predict_issame, np.logical_not(actual_issame)))\n",
    "    tn = np.sum(np.logical_and(np.logical_not(predict_issame), np.logical_not(actual_issame)))\n",
    "    fn = np.sum(np.logical_and(np.logical_not(predict_issame), actual_issame))\n",
    "\n",
    "    is_fp = np.logical_and(predict_issame, np.logical_not(actual_issame))\n",
    "    is_fn = np.logical_and(np.logical_not(predict_issame), actual_issame)\n",
    "\n",
    "    tpr = 0 if (tp+fn==0) else float(tp) / float(tp+fn)\n",
    "    fpr = 0 if (fp+tn==0) else float(fp) / float(fp+tn)\n",
    "    acc = float(tp+tn)/dist.size\n",
    "    return tpr, fpr, acc, is_fp, is_fn\n",
    "\n",
    "def calculate_val(thresholds, embeddings1, embeddings2, actual_issame, far_target, nrof_folds=10, distance_metric=0, subtract_mean=False):\n",
    "    assert(embeddings1.shape[0] == embeddings2.shape[0])\n",
    "    assert(embeddings1.shape[1] == embeddings2.shape[1])\n",
    "    nrof_pairs = min(len(actual_issame), embeddings1.shape[0])\n",
    "    nrof_thresholds = len(thresholds)\n",
    "    k_fold = KFold(n_splits=nrof_folds, shuffle=False)\n",
    "\n",
    "    val = np.zeros(nrof_folds)\n",
    "    far = np.zeros(nrof_folds)\n",
    "\n",
    "    indices = np.arange(nrof_pairs)\n",
    "\n",
    "    for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)):\n",
    "        if subtract_mean:\n",
    "            mean = np.mean(np.concatenate([embeddings1[train_set], embeddings2[train_set]]), axis=0)\n",
    "        else:\n",
    "          mean = 0.0\n",
    "        dist = distance(embeddings1-mean, embeddings2-mean, distance_metric)\n",
    "\n",
    "        # Find the threshold that gives FAR = far_target\n",
    "        far_train = np.zeros(nrof_thresholds)\n",
    "        for threshold_idx, threshold in enumerate(thresholds):\n",
    "            _, far_train[threshold_idx] = calculate_val_far(threshold, dist[train_set], actual_issame[train_set])\n",
    "        if np.max(far_train)>=far_target:\n",
    "            f = interpolate.interp1d(far_train, thresholds, kind='slinear')\n",
    "            threshold = f(far_target)\n",
    "        else:\n",
    "            threshold = 0.0\n",
    "\n",
    "        val[fold_idx], far[fold_idx] = calculate_val_far(threshold, dist[test_set], actual_issame[test_set])\n",
    "\n",
    "    val_mean = np.mean(val)\n",
    "    far_mean = np.mean(far)\n",
    "    val_std = np.std(val)\n",
    "    return val_mean, val_std, far_mean\n",
    "\n",
    "def calculate_val_far(threshold, dist, actual_issame):\n",
    "    predict_issame = np.less(dist, threshold)\n",
    "    true_accept = np.sum(np.logical_and(predict_issame, actual_issame))\n",
    "    false_accept = np.sum(np.logical_and(predict_issame, np.logical_not(actual_issame)))\n",
    "    n_same = np.sum(actual_issame)\n",
    "    n_diff = np.sum(np.logical_not(actual_issame))\n",
    "    val = float(true_accept) / float(n_same)\n",
    "    far = float(false_accept) / float(n_diff)\n",
    "    return val, far\n",
    "\n",
    "\n",
    "\n",
    "def evaluate(embeddings, actual_issame, nrof_folds=10, distance_metric=0, subtract_mean=False):\n",
    "    # Calculate evaluation metrics\n",
    "    thresholds = np.arange(0, 4, 0.01)\n",
    "    embeddings1 = embeddings[0::2]\n",
    "    embeddings2 = embeddings[1::2]\n",
    "    tpr, fpr, accuracy, fp, fn  = calculate_roc(thresholds, embeddings1, embeddings2,\n",
    "        np.asarray(actual_issame), nrof_folds=nrof_folds, distance_metric=distance_metric, subtract_mean=subtract_mean)\n",
    "    thresholds = np.arange(0, 4, 0.001)\n",
    "    val, val_std, far = calculate_val(thresholds, embeddings1, embeddings2,\n",
    "        np.asarray(actual_issame), 1e-3, nrof_folds=nrof_folds, distance_metric=distance_metric, subtract_mean=subtract_mean)\n",
    "    return tpr, fpr, accuracy, val, val_std, far, fp, fn\n",
    "\n",
    "def add_extension(path):\n",
    "    if os.path.exists(path+'.jpg'):\n",
    "        return path+'.jpg'\n",
    "    elif os.path.exists(path+'.png'):\n",
    "        return path+'.png'\n",
    "    else:\n",
    "        raise RuntimeError('No file \"%s\" with extension png or jpg.' % path)\n",
    "\n",
    "def get_paths(lfw_dir, pairs):\n",
    "    nrof_skipped_pairs = 0\n",
    "    path_list = []\n",
    "    issame_list = []\n",
    "    for pair in pairs:\n",
    "        if len(pair) == 3:\n",
    "            path0 = add_extension(os.path.join(lfw_dir, pair[0], pair[0] + '_' + '%04d' % int(pair[1])))\n",
    "            path1 = add_extension(os.path.join(lfw_dir, pair[0], pair[0] + '_' + '%04d' % int(pair[2])))\n",
    "            issame = True\n",
    "        elif len(pair) == 4:\n",
    "            path0 = add_extension(os.path.join(lfw_dir, pair[0], pair[0] + '_' + '%04d' % int(pair[1])))\n",
    "            path1 = add_extension(os.path.join(lfw_dir, pair[2], pair[2] + '_' + '%04d' % int(pair[3])))\n",
    "            issame = False\n",
    "        if os.path.exists(path0) and os.path.exists(path1):    # Only add the pair if both paths exist\n",
    "            path_list += (path0,path1)\n",
    "            issame_list.append(issame)\n",
    "        else:\n",
    "            nrof_skipped_pairs += 1\n",
    "    if nrof_skipped_pairs>0:\n",
    "        print('Skipped %d image pairs' % nrof_skipped_pairs)\n",
    "\n",
    "    return path_list, issame_list\n",
    "\n",
    "def read_pairs(pairs_filename):\n",
    "    pairs = []\n",
    "    with open(pairs_filename, 'r') as f:\n",
    "        for line in f.readlines()[1:]:\n",
    "            pair = line.strip().split()\n",
    "            pairs.append(pair)\n",
    "    return np.array(pairs, dtype=object)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "outputs": [],
   "source": [
    "pairs = read_pairs(pairs_path)\n",
    "path_list, issame_list = get_paths(data_dir+'_cropped', pairs)\n",
    "embeddings = np.array([embeddings_dict[path] for path in path_list])\n",
    "\n",
    "tpr, fpr, accuracy, val, val_std, far, fp, fn = evaluate(embeddings, issame_list)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0.995      0.995      0.99166667 0.99       0.99       0.99666667\n",
      " 0.99       0.995      0.99666667 0.99666667]\n"
     ]
    },
    {
     "data": {
      "text/plain": "0.9936666666666666"
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "print(accuracy)\n",
    "np.mean(accuracy)\n",
    "\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}